package main

import (
	"context"
	"fmt"
	"gorm.io/gorm/logger"
	"log"
	"time"
)

// logTpl is the layout of every line this logger writes. Its verbs are the
// level tag, the formatted message and the trace ID, in that order, producing
// lines such as "[INFO] message. trace:111".
var (
	logTpl = "[%s] %s. trace:%s"
)

// myLogger is a GORM logger that writes through the standard library log
// package. It embeds logger.Config. LogLevel from that config decides which
// messages are written, and SlowThreshold is consulted by Trace.
type myLogger struct{ logger.Config }

// LogMode returns a copy of the logger whose LogLevel is set to level.
// The receiver is a value, so the logger it was called on is not modified.
func (m myLogger) LogMode(level logger.LogLevel) logger.Interface {
	m.LogLevel = level // m is already a private copy of the caller's logger
	return m
}

// Info formats s with i and writes it with the "INFO" tag and the trace ID
// taken from ctx. Nothing is written unless LogLevel is logger.Info or higher.
func (m myLogger) Info(ctx context.Context, s string, i ...interface{}) {
	if m.LogLevel < logger.Info {
		return
	}
	msg := fmt.Sprintf(s, i...)
	log.Printf(logTpl, "INFO", msg, getTrace(ctx))
}

// Warn formats s with i and writes it with the "WARN" tag and the trace ID
// taken from ctx. Nothing is written unless LogLevel is logger.Warn or higher.
func (m myLogger) Warn(ctx context.Context, s string, i ...interface{}) {
	if m.LogLevel < logger.Warn {
		return
	}
	msg := fmt.Sprintf(s, i...)
	log.Printf(logTpl, "WARN", msg, getTrace(ctx))
}

// Error formats s with i and writes it with the "ERROR" tag and the trace ID
// taken from ctx. Nothing is written unless LogLevel is logger.Error or higher.
func (m myLogger) Error(ctx context.Context, s string, i ...interface{}) {
	if m.LogLevel < logger.Error {
		return
	}
	msg := fmt.Sprintf(s, i...)
	log.Printf(logTpl, "ERROR", msg, getTrace(ctx))
}

// Trace logs one executed SQL statement that started at begin. At most one
// line is written, chosen in this priority order:
//   - an ERROR line with the statement and err, when err is non-nil and
//     LogLevel >= logger.Error;
//   - a WARN "slow sql" line, when SlowThreshold is non-zero, the statement
//     took longer than SlowThreshold, and LogLevel >= logger.Warn;
//   - an INFO line with the statement, when LogLevel >= logger.Info.
//
// fc returns the SQL text and row count. It is called only when a line will
// actually be written, so no work is spent building text that is discarded.
func (m myLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rows int64), err error) {
	// Silent (or an unset zero LogLevel) can never log anything; skip fc.
	if m.LogLevel <= logger.Silent {
		return
	}
	// Measure before calling fc so the SQL-rendering time is not counted.
	cost := time.Since(begin)
	switch {
	case err != nil && m.LogLevel >= logger.Error:
		// Include err itself; without it the log line does not say why it failed.
		sql, rows := fc()
		m.Error(ctx, "sql:%s. rows:%d, cost:%v, err:%v", sql, rows, cost, err)
	case m.SlowThreshold != 0 && cost > m.SlowThreshold && m.LogLevel >= logger.Warn:
		// A zero SlowThreshold means "disabled"; otherwise every query is slow.
		sql, rows := fc()
		m.Warn(ctx, "slow sql:%s. rows:%d, cost:%v", sql, rows, cost)
	case m.LogLevel >= logger.Info:
		sql, rows := fc()
		m.Info(ctx, "sql:%s. rows:%d, cost:%v", sql, rows, cost)
	}
}

// 获取trace方法,一般引用第三方库如"github.com/SkyAPM/go2sky"
func getTrace(ctx context.Context) string {
	return "helloworld"
}
